import argparse
import os
import pickle
import torch
import numpy as np
from new_x1_env import BotElfEnv
from rsl_rl.runners import OnPolicyRunner
import genesis as gs
import matplotlib.pyplot as plt

def _parse_args(argv=None):
    """Parse the command-line options for the force-logging rollout.

    Args:
        argv: Argument list to parse. ``None`` makes argparse read
            ``sys.argv[1:]``, which is what the script entry point relies on.

    Returns:
        The parsed ``argparse.Namespace`` carrying ``exp_name``, ``ckpt`` and
        ``max_steps``.

    Raises:
        SystemExit: On invalid arguments (argparse exits with status 2),
            including a non-positive ``--max_steps``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-e", "--exp_name", type=str, default="new_x1_forward")
    parser.add_argument("--ckpt", type=int, default=2000)
    parser.add_argument(
        "--max_steps",
        type=int,
        default=1000,
        help="number of simulation steps to roll out and record",
    )
    args = parser.parse_args(argv)
    # Zero recorded steps would leave nothing to stack or plot, so reject it
    # up front, before the simulator is initialised.
    if args.max_steps < 1:
        parser.error("--max_steps must be a positive integer")
    return args


def _record_forces(env, policy, max_steps):
    """Roll out ``policy`` in ``env`` and record the motor DOF forces per step.

    Args:
        env: Environment exposing ``reset()``, ``step(actions)``,
            ``motor_dofs`` and ``robot``.
        policy: Callable mapping observations to actions.
        max_steps: Number of environment steps to run (must be >= 1).

    Returns:
        A ``(control_forces, internal_forces)`` pair of numpy arrays with one
        row per step. Each row is the flattened result of
        ``get_dofs_control_force`` / ``get_dofs_force`` for the motor DOFs.
    """
    control_forces = []
    internal_forces = []

    # Read once; assumes env.motor_dofs does not change during the rollout.
    dofs_idx = env.motor_dofs

    obs, _ = env.reset()
    with torch.no_grad():
        for step in range(max_steps):
            actions = policy(obs)
            # Only the observation is needed; the remaining step outputs are ignored.
            obs, *_ = env.step(actions)

            control_force = env.robot.get_dofs_control_force(dofs_idx)
            internal_force = env.robot.get_dofs_force(dofs_idx)

            # Flatten to 1D so every step contributes exactly one row.
            control_forces.append(control_force.cpu().numpy().flatten())
            internal_forces.append(internal_force.cpu().numpy().flatten())

            print(f"Step {step + 1}:")
            print("control force:", control_force)
            print("internal force:", internal_force)

    # Resulting shape is (max_steps, values per step).
    return np.stack(control_forces), np.stack(internal_forces)


def _plot_forces(control_forces, internal_forces):
    """Plot control vs. internal force over time, one subplot per joint.

    Args:
        control_forces: 2D array, rows are time steps and columns are joints.
        internal_forces: 2D array with the same layout as ``control_forces``.
    """
    num_joints = control_forces.shape[1]
    if num_joints == 0:
        # plt.subplots rejects a zero-row grid, so bail out early.
        print("No joint forces were recorded; nothing to plot.")
        return

    # squeeze=False always yields a 2D axes array, so the single-joint case
    # needs no special handling.
    _, axes = plt.subplots(
        num_joints, 1, figsize=(10, 6 * num_joints), squeeze=False
    )

    for i, ax in enumerate(axes[:, 0]):
        ax.plot(control_forces[:, i], label=f'Control Force (Joint {i + 1})')
        ax.plot(internal_forces[:, i], label=f'Internal Force (Joint {i + 1})')
        ax.set_xlabel('Time Step')
        ax.set_ylabel('Force')
        ax.set_title(f'Control Force vs Internal Force (Joint {i + 1})')
        ax.legend()

    # Keep titles and axis labels of neighbouring subplots from overlapping.
    plt.tight_layout()
    plt.show()


def main(argv=None):
    """Replay a trained policy and plot control vs. internal DOF forces.

    Loads the pickled configs and the ``model_<ckpt>.pt`` checkpoint from
    ``logs/<exp_name>``, runs the policy in a single viewer-enabled
    environment for ``--max_steps`` steps while recording the motor DOF
    forces, then shows one subplot per joint.

    Args:
        argv: Optional argument list; ``None`` uses the process command line.
    """
    args = _parse_args(argv)

    gs.init()

    log_dir = f"logs/{args.exp_name}"
    # NOTE: pickle can execute arbitrary code while loading; only point this
    # script at log directories you trust.
    with open(f"{log_dir}/cfgs.pkl", "rb") as cfg_file:
        env_cfg, obs_cfg, reward_cfg, command_cfg, train_cfg = pickle.load(cfg_file)
    # Replace the stored reward scales with an empty dict for this replay run.
    reward_cfg["reward_scales"] = {}

    env = BotElfEnv(
        num_envs=1,
        env_cfg=env_cfg,
        obs_cfg=obs_cfg,
        reward_cfg=reward_cfg,
        command_cfg=command_cfg,
        show_viewer=True,
    )

    # Build the runner and restore the requested checkpoint.
    runner = OnPolicyRunner(env, train_cfg, log_dir, device="cuda:0")
    resume_path = os.path.join(log_dir, f"model_{args.ckpt}.pt")
    runner.load(resume_path)
    policy = runner.get_inference_policy(device="cuda:0")

    control_forces, internal_forces = _record_forces(env, policy, args.max_steps)
    _plot_forces(control_forces, internal_forces)

# Run the rollout only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()